#include<cstdio>
#include<cctype>
// Despite its name, this constant is used as the modulus applied to the
// final answer in main().
const long long maxn=998244353;
using namespace std;
// Reads one (optionally negative) decimal integer from stdin; defined below main().
inline long long get();
// The two input values, filled in by main() via get().
long long n,m;
// The result printed by main().
long long ans;
/**
 * Reads two integers n and m from "bpmp.in" and writes
 * ((m-1) + m*(n-1)) mod maxn  (i.e. m*n - 1 modulo maxn) to "bpmp.out".
 *
 * Returns 0 on success, 1 if either file redirection fails.
 */
int main()
{
	// Without working redirection nothing useful can be read or written.
	if(!freopen("bpmp.in","r",stdin))return 1;
	if(!freopen("bpmp.out","w",stdout))return 1;
	n=get();m=get();
	// Reduce every operand modulo maxn *before* multiplying: the product is
	// then bounded by maxn^2 (< 2^60), so it cannot overflow long long even
	// when n and m are close to the limits of the type.
	const long long first=(m-1)%maxn;
	const long long second=(m%maxn)*((n-1)%maxn)%maxn;
	ans=(first+second)%maxn;
	// C++ '%' keeps the sign of the dividend; report the canonical residue
	// in [0, maxn).
	if(ans<0)ans+=maxn;
	printf("%lld",ans);
	return 0;
}
/**
 * Reads the next decimal integer from stdin.
 *
 * Skips every character that is not a digit; if a '-' is seen anywhere in
 * the skipped prefix the result is negated. Then consumes the run of digits
 * (plus the single character that terminates it).
 *
 * Returns the parsed value, or 0 if end of input is reached before any
 * digit is found.
 */
inline long long get()
{
	long long t=0,jud=1;
	// Keep the character in an int: getchar() returns EOF out-of-band, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int c=std::getchar();
	while(c!=EOF&&!std::isdigit(c))
	{
		if(c=='-')jud=-1;
		c=std::getchar();
	}
	// isdigit(EOF) is false, so this loop also stops at end of input.
	while(std::isdigit(c))
	{
		t=t*10+(c-'0');
		c=std::getchar();
	}
	return t*jud;
}

